# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Differentially private optimizers for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from privacy.analysis import privacy_ledger
from privacy.dp_query import gaussian_query


def make_optimizer_class(cls):
  """Constructs a DP optimizer class from an existing one."""
  if (tf.train.Optimizer.compute_gradients.__code__ is
      not cls.compute_gradients.__code__):
    tf.logging.warning(
        'WARNING: Calling make_optimizer_class() on class %s that overrides '
        'method compute_gradients(). Check to ensure that '
        'make_optimizer_class() does not interfere with overridden version.',
        cls.__name__)

  class DPOptimizerClass(cls):
    """Differentially private subclass of given class cls."""
    def __init__(
        self,
        dp_average_query,
        num_microbatches,
        unroll_microbatches=False,
        *args,  # pylint: disable=keyword-arg-before-vararg
        **kwargs):
      super(DPOptimizerClass, self).__init__(*args, **kwargs)
      self._dp_average_query = dp_average_query
      self._num_microbatches = num_microbatches
      self._global_state = self._dp_average_query.initial_global_state()
      # TODO(b/122613513): Set unroll_microbatches=True to avoid this bug.
      # Beware: When num_microbatches is large (>100), enabling this parameter
      # may cause an OOM error.
      self._unroll_microbatches = unroll_microbatches

    # Compute differentially private gradients of `loss` w.r.t. `var_list`.
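    # The wrapped DPQuery is driven through this call sequence (roughly):
    #   initial_global_state()  - persistent clipping/noise state across steps
    #   derive_sample_params()  - per-step parameters derived from that state
    #   initial_sample_state()  - a zeroed accumulator shaped like var_list
    #   accumulate_record()     - clip and add one microbatch's gradients
    #   get_noised_result()     - noised average plus the updated global state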
    def compute_gradients(self,
                          loss,
                          var_list,
                          gate_gradients=tf.train.Optimizer.GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None,
                          gradient_tape=None):
      if callable(loss):
        # TF is running in Eager mode, check we received a vanilla tape.
        if not gradient_tape:
          raise ValueError('When in Eager mode, a tape needs to be passed.')

        vector_loss = loss()
        sample_state = self._dp_average_query.initial_sample_state(
            self._global_state, var_list)
        # Group the per-example losses into num_microbatches microbatches.
        microbatches_losses = tf.reshape(vector_loss,
                                         [self._num_microbatches, -1])
        sample_params = (
            self._dp_average_query.derive_sample_params(self._global_state))

        def process_microbatch(i, sample_state):
          """Process one microbatch (record) with privacy helper."""
          microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i]))
          grads = gradient_tape.gradient(microbatch_loss, var_list)
          sample_state = self._dp_average_query.accumulate_record(sample_params,
                                                                  sample_state,
                                                                  grads)
          return sample_state

        for idx in range(self._num_microbatches):
          sample_state = process_microbatch(idx, sample_state)

        # Compute the noised, averaged gradients and the updated query state.
        final_grads, self._global_state = (
            self._dp_average_query.get_noised_result(sample_state,
                                                     self._global_state))
        grads_and_vars = list(zip(final_grads, var_list))
        return grads_and_vars

      else:
        # TF is running in graph mode, check we did not receive a gradient tape.
        if gradient_tape:
          raise ValueError('When in graph mode, a tape should not be passed.')

        # Note: it would be closer to the correct i.i.d. sampling of records if
        # we sampled each microbatch from the appropriate binomial distribution,
        # although that still wouldn't be quite correct because it would be
        # sampling from the dataset without replacement.
        microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
        sample_params = (
            self._dp_average_query.derive_sample_params(self._global_state))

        def process_microbatch(i, sample_state):
          """Process one microbatch (record) with privacy helper."""
          grads, _ = zip(*super(cls, self).compute_gradients(
              tf.reduce_mean(tf.gather(microbatches_losses,
                                       [i])), var_list, gate_gradients,
              aggregation_method, colocate_gradients_with_ops, grad_loss))
          grads_list = list(grads)
          sample_state = self._dp_average_query.accumulate_record(
              sample_params, sample_state, grads_list)
          return sample_state

        if var_list is None:
          var_list = (
              tf.trainable_variables() + tf.get_collection(
                  tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

        sample_state = self._dp_average_query.initial_sample_state(
            self._global_state, var_list)

        if self._unroll_microbatches:
          for idx in range(self._num_microbatches):
            sample_state = process_microbatch(idx, sample_state)
        else:
          # Use of while_loop here requires that sample_state be a nested
          # structure of tensors. In general, we would prefer to allow it to be
          # an arbitrary opaque type.
          cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
          body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
          idx = tf.constant(0)
          _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state])

        final_grads, self._global_state = (
            self._dp_average_query.get_noised_result(
                sample_state, self._global_state))

        return list(zip(final_grads, var_list))

  return DPOptimizerClass
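
# A hedged usage sketch (not executed here), assuming GaussianAverageQuery can
# be constructed without a ledger and that `vector_loss` is a per-example
# (unreduced) loss tensor defined elsewhere; the loss must stay unreduced so
# compute_gradients() can reshape it into `num_microbatches` microbatches.
#
#   DPGradientDescent = make_optimizer_class(tf.train.GradientDescentOptimizer)
#   dp_query = gaussian_query.GaussianAverageQuery(
#       1.0,        # l2_norm_clip
#       1.0 * 1.1,  # noise stddev = l2_norm_clip * noise_multiplier
#       4)          # denominator = num_microbatches
#   opt = DPGradientDescent(dp_query, num_microbatches=4, learning_rate=0.1)
#   train_op = opt.minimize(loss=vector_loss)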


def make_gaussian_optimizer_class(cls):
  """Constructs a DP optimizer with Gaussian averaging of updates."""

  class DPGaussianOptimizerClass(make_optimizer_class(cls)):
    """DP subclass of given class cls using Gaussian averaging."""
    # Parameters needed to construct a Gaussian-averaging DP optimizer.
    def __init__(
        self,
        l2_norm_clip,
        noise_multiplier,
        num_microbatches,
        ledger,  # PrivacyLedger recording sample and query events.
        unroll_microbatches=False,
        *args,  # pylint: disable=keyword-arg-before-vararg
        **kwargs):

      # Build the query that clips each microbatch's gradient to l2_norm_clip
      # and adds Gaussian noise with stddev l2_norm_clip * noise_multiplier to
      # the averaged update used by the descent step.
      dp_average_query = gaussian_query.GaussianAverageQuery(
          l2_norm_clip, l2_norm_clip * noise_multiplier,
          num_microbatches, ledger)

      # Keep a handle to the ledger so the `ledger` property below works.
      self._ledger = ledger

      if ledger:
        dp_average_query = privacy_ledger.QueryWithLedger(
            dp_average_query, ledger)

      super(DPGaussianOptimizerClass, self).__init__(
          dp_average_query,
          num_microbatches,
          unroll_microbatches,
          *args,
          **kwargs)

    @property
    def ledger(self):
      return self._ledger

  return DPGaussianOptimizerClass
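
# A hedged sketch of using the Gaussian convenience factory (not executed
# here). `ledger` may be a privacy_ledger.PrivacyLedger or, assuming the query
# tolerates its absence, None; remaining keyword arguments such as
# learning_rate are forwarded to the wrapped optimizer class.
#
#   DPGaussianSGD = make_gaussian_optimizer_class(
#       tf.train.GradientDescentOptimizer)
#   opt = DPGaussianSGD(
#       l2_norm_clip=1.0,
#       noise_multiplier=1.1,
#       num_microbatches=4,
#       ledger=None,
#       learning_rate=0.15)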

# Compatibility with tf 1 and 2 APIs
# Resolving the base optimizer classes works in two steps:
#   1. Pick the standard (non-noised) optimizer class from whichever TF API is
#      available in the installed version.
#   2. Pass that class to the DP class factories below, which return the
#      noise-adding (differentially private) variants.
try:
  AdagradOptimizer = tf.train.AdagradOptimizer
  AdamOptimizer = tf.train.AdamOptimizer
  GradientDescentOptimizer = tf.train.GradientDescentOptimizer
  AdadeltaOptimizer = tf.train.AdadeltaOptimizer  # Added for this project's custom network.
except:  # pylint: disable=bare-except
  AdagradOptimizer = tf.optimizers.Adagrad
  AdamOptimizer = tf.optimizers.Adam
  GradientDescentOptimizer = tf.optimizers.SGD  # pylint: disable=invalid-name
  AdadeltaOptimizer = tf.optimizers.Adadelta  # Keep Adadelta defined under TF 2 as well.

DPAdagradOptimizer = make_optimizer_class(AdagradOptimizer)
DPAdamOptimizer = make_optimizer_class(AdamOptimizer)
DPGradientDescentOptimizer = make_optimizer_class(GradientDescentOptimizer)
DPAdadeltaOptimizer = make_optimizer_class(AdadeltaOptimizer)


DPAdagradGaussianOptimizer = make_gaussian_optimizer_class(AdagradOptimizer)
DPAdamGaussianOptimizer = make_gaussian_optimizer_class(AdamOptimizer)
DPGradientDescentGaussianOptimizer = make_gaussian_optimizer_class(
    GradientDescentOptimizer)
DPAdadeltaGaussianOptimizer = make_gaussian_optimizer_class(AdadeltaOptimizer)
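
# End-to-end usage sketch (hedged, not part of this module): the DP optimizers
# require a vector of per-example losses, e.g. cross-entropy with reduction
# disabled, so compute_gradients() can split it into microbatches. `logits`
# and `labels` are hypothetical placeholders defined by the surrounding model.
#
#   vector_loss = tf.losses.sparse_softmax_cross_entropy(
#       labels=labels, logits=logits, reduction=tf.losses.Reduction.NONE)
#   optimizer = DPGradientDescentGaussianOptimizer(
#       l2_norm_clip=1.0,
#       noise_multiplier=1.1,
#       num_microbatches=32,
#       ledger=None,
#       learning_rate=0.15)
#   train_op = optimizer.minimize(
#       loss=vector_loss, global_step=tf.train.get_global_step())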

